#pragma once

#include "global_variable.hpp"
#include <cstdint>
#include <cstring>
#include <map>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <stdexcept>
#include <sys/types.h>

namespace py = pybind11;
// Cached handles to the Python type objects that the helpers below compare
// against with py::isinstance.
//
// Each handle is an `inline` variable, so the whole program shares a single
// instance even though this header may be included from several translation
// units.
//
// NOTE(review): the initializers execute Python code (importing numpy,
// constructing temporary int/float/bool/list/str objects) when these variables
// are initialized. Presumably that requires the Python interpreter to be
// running and numpy to be importable at that point — confirm for every binary
// that includes this header.
inline auto PyType_NdArray = py::module_::import("numpy").attr("ndarray");
// The built-in types are obtained by default-constructing an instance and
// asking it for its type.
inline auto PyType_Integer = py::int_().get_type();
inline auto PyType_Float = py::float_().get_type();
inline auto PyType_Bool = py::bool_().get_type();
inline auto PyType_List = py::list().get_type();
inline auto PyType_Str = py::str().get_type();

/// Fixed-size header of the flat record used to store an ndarray in a raw
/// byte buffer.
///
/// The header is immediately followed, in the same allocation, by:
///   1. `ndims` shape entries   (ssize_t each)  -> shape()
///   2. `ndims` stride entries  (ssize_t each)  -> strides()
///   3. the raw element bytes                   -> ptr()
///
/// The accessors only compute addresses relative to `this`; they are valid
/// only when the object lives at the start of a buffer large enough to hold
/// the trailing sections, and only after `ndims` has been set.
struct BuffProto {
  /// Size in bytes of one element.
  ssize_t itemsize = 0;
  /// Total number of elements.
  ssize_t size = 0;
  /// Element format code. Logically shorter, but padding would round a smaller
  /// array up anyway, so the field is declared with 4 bytes up front.
  char format[4] = {};
  /// Whether the source buffer was read-only.
  bool readonly = false;
  /// Number of dimensions; determines the length of the shape/stride sections.
  /// Must remain the LAST data member: the trailing sections are located
  /// relative to its address.
  ssize_t ndims = 0;

  /// Start of the shape section (`ndims` entries), directly after `ndims`.
  ssize_t *shape() { return &ndims + 1; }
  const ssize_t *shape() const { return &ndims + 1; }

  /// Start of the strides section (`ndims` entries), directly after the shape.
  ssize_t *strides() { return shape() + ndims; }
  const ssize_t *strides() const { return shape() + ndims; }

  /// Start of the raw element bytes, directly after the strides.
  void *ptr() { return static_cast<void *>(strides() + ndims); }
  const void *ptr() const {
    return static_cast<const void *>(strides() + ndims);
  }
};

// Tag identifying the kind of Python object being handled by the helpers in
// this header. The numeric values are assigned explicitly.
enum class PyDataType {
  // Object of a type none of the helpers recognise.
  Unsupported = -1,
  // Generic Python object. Not produced by get_pyobj_dtype in this file.
  PyObj = 0,
  // Scalar tags.
  PyInteger = 1,
  PyFloat = 2,
  PyBool = 3,
  PyStr = 4,
  // Container tags.
  PyList = 10,
  // Not produced by get_pyobj_dtype in this file.
  PyDict = 11,
  // numpy.ndarray.
  NdArray = 20,
};

/// Classifies a Python object into the PyDataType tag used by this header.
///
/// The checks use isinstance semantics, so subclasses match their base type
/// and the order of the checks matters.
///
/// @param obj  Python object to classify.
/// @return     The matching tag, or PyDataType::Unsupported when the object is
///             none of ndarray / str / bool / int / float / list.
inline PyDataType get_pyobj_dtype(py::object const &obj) {
  if (py::isinstance(obj, PyType_NdArray)) {
    return PyDataType::NdArray;
  }
  if (py::isinstance(obj, PyType_Str)) {
    return PyDataType::PyStr;
  }
  // Python's bool is a subclass of int, so it has to be tested BEFORE int;
  // otherwise every bool is reported as PyInteger and this branch is dead.
  if (py::isinstance(obj, PyType_Bool)) {
    return PyDataType::PyBool;
  }
  if (py::isinstance(obj, PyType_Integer)) {
    return PyDataType::PyInteger;
  }
  if (py::isinstance(obj, PyType_Float)) {
    return PyDataType::PyFloat;
  }
  if (py::isinstance(obj, PyType_List)) {
    return PyDataType::PyList;
  }
  return PyDataType::Unsupported;
}

/// Returns the number of bytes get_pyobj_data needs to store `obj`.
///
/// The size is derived from get_pyobj_dtype so that the tag and the size can
/// never disagree about how an object is classified.
///
///   - ndarray : sizeof(BuffProto) + ndim shape entries + ndim stride entries
///               + the raw element bytes
///   - str     : a size_t length prefix + the UTF-8 bytes
///   - int     : sizeof(int64_t)
///   - float   : sizeof(double)
///   - bool    : sizeof(bool)
///   - list    : 0 (not implemented yet)
///   - other   : 0
///
/// @param obj  Python object to measure.
/// @return     Required buffer size in bytes, or 0 for unsupported objects.
/// @throws py::value_error  if an ndarray has a dimension of extent <= 0.
inline size_t get_pyobj_size(py::object const &obj) {
  switch (get_pyobj_dtype(obj)) {
  case PyDataType::NdArray: {
    auto buff = py::cast<py::buffer>(obj);
    auto info = buff.request();
    ssize_t count = 1;
    for (ssize_t i = 0; i < info.ndim; ++i) {
      if (info.shape[i] <= 0) {
        throw py::value_error("ndarray shape must > 0");
      }
      count *= info.shape[i];
    }
    // The shape and the strides each take one ssize_t PER DIMENSION; a fixed
    // allowance is only large enough for 1-D arrays.
    const size_t header_bytes =
        sizeof(BuffProto) +
        2 * static_cast<size_t>(info.ndim) * sizeof(ssize_t);
    const size_t data_bytes =
        static_cast<size_t>(count) * static_cast<size_t>(info.itemsize);
    return header_bytes + data_bytes;
  }
  case PyDataType::PyStr: {
    auto str = py::cast<std::string_view>(obj);
    return sizeof(size_t) + str.size();
  }
  case PyDataType::PyInteger:
    return sizeof(int64_t);
  case PyDataType::PyFloat:
    return sizeof(double);
  case PyDataType::PyBool:
    return sizeof(bool);
  case PyDataType::PyList:
    // TODO: lists are not serialized yet.
    return 0;
  default:
    return 0;
  }
}

/// Serializes `obj` into the caller-provided buffer `dest`.
///
/// Layout written per `dtype`:
///   - PyFloat / PyBool / PyInteger : the raw bytes of a double / bool /
///     int64_t. At most `size` bytes are written and never more than the
///     scalar itself occupies.
///   - PyStr   : a size_t byte length followed by the UTF-8 bytes (no
///     terminator).
///   - NdArray : a BuffProto header, then `ndim` shape entries, `ndim` stride
///     entries and the raw element bytes.
///   - any other tag : nothing is written.
///
/// @param obj    Python object to serialize; must be convertible to the C++
///               type implied by `dtype`.
/// @param dtype  Tag selecting the encoding.
/// @param dest   Destination buffer. For NdArray it is accessed as a
///               BuffProto, so it must be suitably aligned for ssize_t.
/// @param size   Capacity of `dest` in bytes.
/// @throws py::value_error  if `size` is too small for a str or ndarray.
inline void get_pyobj_data(py::object const &obj, PyDataType dtype, void *dest,
                           size_t size) {
  // Copies a fixed-width scalar without reading past the scalar (when `size`
  // is larger than it) and without writing past `size` (when it is smaller).
  auto copy_scalar = [dest, size](auto value) {
    const size_t nbytes = size < sizeof(value) ? size : sizeof(value);
    memcpy(dest, &value, nbytes);
  };

  if (dtype == PyDataType::PyFloat) {
    copy_scalar(py::cast<double>(obj));
  } else if (dtype == PyDataType::PyBool) {
    copy_scalar(py::cast<bool>(obj));
  } else if (dtype == PyDataType::PyInteger) {
    copy_scalar(py::cast<int64_t>(obj));
  } else if (dtype == PyDataType::PyStr) {
    auto text = py::cast<std::string_view>(obj);
    const size_t length = text.size();
    if (size < sizeof(size_t) || size - sizeof(size_t) < length) {
      throw py::value_error("destination buffer too small for str");
    }
    // memcpy the prefix instead of storing through a size_t*: `dest` is not
    // guaranteed to be aligned for size_t.
    memcpy(dest, &length, sizeof(length));
    memcpy(static_cast<char *>(dest) + sizeof(size_t), text.data(), length);
  } else if (dtype == PyDataType::NdArray) {
    auto buff = py::cast<py::buffer>(obj);
    auto info = buff.request();

    const size_t dims_bytes = static_cast<size_t>(info.ndim) * sizeof(ssize_t);
    const size_t data_bytes =
        static_cast<size_t>(info.size) * static_cast<size_t>(info.itemsize);
    if (size < sizeof(BuffProto) + 2 * dims_bytes + data_bytes) {
      throw py::value_error("destination buffer too small for ndarray");
    }

    auto bp = static_cast<BuffProto *>(dest);
    bp->itemsize = info.itemsize;
    bp->size = info.size;
    // strncpy does not terminate on truncation and never touches the last
    // byte here, so terminate explicitly.
    strncpy(bp->format, info.format.c_str(), sizeof(bp->format) - 1);
    bp->format[sizeof(bp->format) - 1] = '\0';
    bp->readonly = info.readonly;
    // ndims must be set before shape()/strides()/ptr() are used: their
    // addresses depend on it.
    bp->ndims = info.ndim;
    memcpy(bp->shape(), info.shape.data(), dims_bytes);
    memcpy(bp->strides(), info.strides.data(), dims_bytes);
    // NOTE(review): this copies size * itemsize bytes starting at info.ptr,
    // which presumably is only correct for arrays whose elements are stored
    // contiguously — confirm how strided / reversed views should be handled.
    memcpy(bp->ptr(), info.ptr, data_bytes);
  }
}
